from langchain_core.tools import tool
from langchain_core.prompts import ChatPromptTemplate
from operator import itemgetter
from langchain_core.output_parsers import JsonOutputParser
from langchain.tools.render import render_text_description
from langdev_helper.llm.qwen import llm

@tool
def multiply(first_int: int, second_int: int) -> int:
    """将两个整数相乘。"""
    return first_int * second_int

@tool
def add(first_int: int, second_int: int) -> int:
    "将两个整数相加。"
    return first_int + second_int

@tool
def exponentiate(base: int, exponent: int) -> int:
    "对底数求指数幂。"
    return base**exponent

tools = [add, exponentiate, multiply]

def tool_chain(model_output):
    tool_map = {tool.name: tool for tool in tools}
    chosen_tool = tool_map[model_output["name"]]
    return itemgetter("arguments") | chosen_tool

rendered_tools = render_text_description(tools)

system_prompt = f"""您是一名助理，可以使用以下工具集。 以下是每个工具的名称和说明:

{rendered_tools}

根据用户输入，返回要使用的工具的名称和输入。 将您的响应作为带有'name'和'arguments'键的 JSON blob 返回，“arguments”键对应的值应该是所选函数的输入参数的字典，字典里不要有任何说明,此JSON blob必须是如下格式：```json
...
```"""

prompt = ChatPromptTemplate.from_messages(
    [("system", system_prompt), ("user", "{input}")]
)


chain = prompt | llm | JsonOutputParser() | tool_chain
chain1= prompt | llm | JsonOutputParser()


ret = chain1.invoke({"input": "3加3等于"})

print(ret)


